{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1Pi_B2cvdBiW"
      },
      "source": [
        "##### Copyright 2021 The TF-Agents Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "nQnmcm0oI1Q-"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GiI8CZYWcJ5n"
      },
      "source": [
        "# Networks\n",
        "\n",
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/agents/tutorials/8_networks_tutorial\">\n",
        "    <img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />\n",
        "    View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/agents/blob/master/docs/tutorials/8_networks_tutorial.ipynb\">\n",
        "    <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />\n",
        "    Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/agents/blob/master/docs/tutorials/8_networks_tutorial.ipynb\">\n",
        "    <img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />\n",
        "    View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/agents/docs/tutorials/8_networks_tutorial.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "31uij8nIo5bG"
      },
      "source": [
        "## Introduction\n",
        "\n",
        "In this colab we will cover how to define custom networks for your agents. The networks help us define the model that is trained by agents. In TF-Agents you will find several different types of networks which are useful across agents:\n",
        "\n",
        "**Main Networks**\n",
        "\n",
        "* **QNetwork**: Used in Qlearning for environments with discrete actions, this network maps an observation to value estimates for each possible action.\n",
        "* **CriticNetworks**: Also referred to as `ValueNetworks` in literature, learns to estimate some version of a Value function mapping some state into an estimate for the expected return of a policy. These networks estimate how good the state the agent is currently in is.\n",
        "* **ActorNetworks**: Learn a mapping from observations to actions. These networks are usually used by our policies to generate actions.\n",
        "* **ActorDistributionNetworks**: Similar to `ActorNetworks` but these generate a distribution which a policy can then sample to generate actions.\n",
        "\n",
        "**Helper Networks**\n",
        "* **EncodingNetwork**: Allows users to easily define a mapping of pre-processing layers to apply to a network's input.\n",
        "* **DynamicUnrollLayer**: Automatically resets the network's state on episode boundaries as it is applied over a time sequence.\n",
        "* **ProjectionNetwork**: Networks like `CategoricalProjectionNetwork` or `NormalProjectionNetwork` take inputs and generate the required parameters to generate Categorical, or Normal distributions.\n",
        "\n",
        "All examples in TF-Agents come with pre-configured networks. However these networks are not setup to handle complex observations.\n",
        "\n",
        "If you have an environment which exposes more than one observation/action and you need to customize your networks then this tutorial is for you!"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Wmk1GBT9cPqC"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uhGeL1Kpc3Pw"
      },
      "source": [
        "If you haven't installed tf-agents yet, run:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xsLTHlVdiZP3"
      },
      "outputs": [],
      "source": [
        "!pip install tf-agents"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sdvop99JlYSM"
      },
      "outputs": [],
      "source": [
        "from __future__ import absolute_import\n",
        "from __future__ import division\n",
        "from __future__ import print_function\n",
        "\n",
        "import abc\n",
        "import tensorflow as tf\n",
        "import numpy as np\n",
        "\n",
        "from tf_agents.environments import random_py_environment\n",
        "from tf_agents.environments import tf_py_environment\n",
        "from tf_agents.networks import encoding_network\n",
        "from tf_agents.networks import network\n",
        "from tf_agents.networks import utils\n",
        "from tf_agents.specs import array_spec\n",
        "from tf_agents.utils import common as common_utils\n",
        "from tf_agents.utils import nest_utils\n",
        "\n",
        "tf.compat.v1.enable_v2_behavior()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ums84-YP_21F"
      },
      "source": [
        "## Defining Networks\n",
        "\n",
        "### Network API\n",
        "\n",
        "In TF-Agents we subclass from Keras [Networks](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/engine/network.py). With it we can:\n",
        "\n",
        "* Simplify copy operations required when creating target networks.\n",
        "* Perform automatic variable creation when calling `network.variables()`.\n",
        "* Validate inputs based on network input_specs.\n",
        "\n",
        "##EncodingNetwork\n",
        "As mentioned above the `EncodingNetwork` allows us to easily define a mapping of pre-processing layers to apply to a network's input to generate some encoding.\n",
        "\n",
        "The EncodingNetwork is composed of the following mostly optional layers:\n",
        "\n",
        "  * Preprocessing layers\n",
        "  * Preprocessing combiner\n",
        "  * Conv2D \n",
        "  * Flatten\n",
        "  * Dense \n",
        "\n",
        "The special thing about encoding networks is that input preprocessing is applied. Input preprocessing is possible via `preprocessing_layers` and `preprocessing_combiner` layers.  Each of these can be specified as a nested structure. If the `preprocessing_layers` nest is shallower than `input_tensor_spec`, then the layers will get the subnests. For example, if:\n",
        "\n",
        "```\n",
        "input_tensor_spec = ([TensorSpec(3)] * 2, [TensorSpec(3)] * 5)\n",
        "preprocessing_layers = (Layer1(), Layer2())\n",
        "```\n",
        "\n",
        "then preprocessing will call:\n",
        "\n",
        "```\n",
        "preprocessed = [preprocessing_layers[0](observations[0]),\n",
        "                preprocessing_layers[1](observations[1])]\n",
        "```\n",
        "\n",
        "However if\n",
        "\n",
        "```\n",
        "preprocessing_layers = ([Layer1() for _ in range(2)],\n",
        "                        [Layer2() for _ in range(5)])\n",
        "```\n",
        "\n",
        "then preprocessing will call:\n",
        "\n",
        "```python\n",
        "preprocessed = [\n",
        "  layer(obs) for layer, obs in zip(flatten(preprocessing_layers),\n",
        "                                    flatten(observations))\n",
        "]\n",
        "```\n"
      ]
    },
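    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The two cases above can be sketched in plain Python. This is only an illustration of the matching rule, not the `EncodingNetwork` implementation: the stand-in layers `double` and `negate` and the helpers `flatten`, `apply_shallow` and `apply_matching` are hypothetical, and observations are plain numbers and lists instead of tensors.\n",
        "\n",
        "```python\n",
        "# Hypothetical stand-ins for Keras layers. Each accepts either a single\n",
        "# observation (a number) or a whole subnest (a list of numbers).\n",
        "def double(x):\n",
        "    return [2 * v for v in x] if isinstance(x, list) else 2 * x\n",
        "\n",
        "def negate(x):\n",
        "    return [-v for v in x] if isinstance(x, list) else -x\n",
        "\n",
        "def flatten(nest):\n",
        "    # Flatten nested tuples and lists into a flat list of leaves.\n",
        "    if isinstance(nest, (list, tuple)):\n",
        "        return [leaf for sub in nest for leaf in flatten(sub)]\n",
        "    return [nest]\n",
        "\n",
        "def apply_shallow(layers, observations):\n",
        "    # One layer per top-level entry: each layer receives its whole subnest.\n",
        "    return [layer(obs) for layer, obs in zip(layers, observations)]\n",
        "\n",
        "def apply_matching(layers, observations):\n",
        "    # One layer per leaf: layers and observations are paired leaf by leaf.\n",
        "    return [layer(obs)\n",
        "            for layer, obs in zip(flatten(layers), flatten(observations))]\n",
        "\n",
        "observations = ([1, 2], [3, 4, 5, 6, 7])\n",
        "shallow = apply_shallow((double, negate), observations)\n",
        "matching = apply_matching(([double] * 2, [negate] * 5), observations)\n",
        "print(shallow)   # [[2, 4], [-3, -4, -5, -6, -7]]\n",
        "print(matching)  # [2, 4, -3, -4, -5, -6, -7]\n",
        "```\n",
        "\n",
        "In the shallow case each layer must be able to consume a whole subnest, while in the matching case every layer sees exactly one observation."
      ]
    },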
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RP3H1bw0ykro"
      },
      "source": [
        "### Custom Networks\n",
        "\n",
        "To create your own networks you will only have to override the `__init__` and `call` methods. Let's create a custom network using what we learned about `EncodingNetworks` to create an ActorNetwork that takes observations which contain an image and a vector.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Zp0TjAJhYo4s"
      },
      "outputs": [],
      "source": [
        "class ActorNetwork(network.Network):\n",
        "\n",
        "  def __init__(self,\n",
        "               observation_spec,\n",
        "               action_spec,\n",
        "               preprocessing_layers=None,\n",
        "               preprocessing_combiner=None,\n",
        "               conv_layer_params=None,\n",
        "               fc_layer_params=(75, 40),\n",
        "               dropout_layer_params=None,\n",
        "               activation_fn=tf.keras.activations.relu,\n",
        "               enable_last_layer_zero_initializer=False,\n",
        "               name='ActorNetwork'):\n",
        "    super(ActorNetwork, self).__init__(\n",
        "        input_tensor_spec=observation_spec, state_spec=(), name=name)\n",
        "\n",
        "    # For simplicity we will only support a single action float output.\n",
        "    self._action_spec = action_spec\n",
        "    flat_action_spec = tf.nest.flatten(action_spec)\n",
        "    if len(flat_action_spec) > 1:\n",
        "      raise ValueError('Only a single action is supported by this network')\n",
        "    self._single_action_spec = flat_action_spec[0]\n",
        "    if self._single_action_spec.dtype not in [tf.float32, tf.float64]:\n",
        "      raise ValueError('Only float actions are supported by this network.')\n",
        "\n",
        "    kernel_initializer = tf.keras.initializers.VarianceScaling(\n",
        "        scale=1. / 3., mode='fan_in', distribution='uniform')\n",
        "    self._encoder = encoding_network.EncodingNetwork(\n",
        "        observation_spec,\n",
        "        preprocessing_layers=preprocessing_layers,\n",
        "        preprocessing_combiner=preprocessing_combiner,\n",
        "        conv_layer_params=conv_layer_params,\n",
        "        fc_layer_params=fc_layer_params,\n",
        "        dropout_layer_params=dropout_layer_params,\n",
        "        activation_fn=activation_fn,\n",
        "        kernel_initializer=kernel_initializer,\n",
        "        batch_squash=False)\n",
        "\n",
        "    initializer = tf.keras.initializers.RandomUniform(\n",
        "        minval=-0.003, maxval=0.003)\n",
        "\n",
        "    self._action_projection_layer = tf.keras.layers.Dense(\n",
        "        flat_action_spec[0].shape.num_elements(),\n",
        "        activation=tf.keras.activations.tanh,\n",
        "        kernel_initializer=initializer,\n",
        "        name='action')\n",
        "\n",
        "  def call(self, observations, step_type=(), network_state=()):\n",
        "    outer_rank = nest_utils.get_outer_rank(observations, self.input_tensor_spec)\n",
        "    # We use batch_squash here in case the observations have a time sequence\n",
        "    # compoment.\n",
        "    batch_squash = utils.BatchSquash(outer_rank)\n",
        "    observations = tf.nest.map_structure(batch_squash.flatten, observations)\n",
        "\n",
        "    state, network_state = self._encoder(\n",
        "        observations, step_type=step_type, network_state=network_state)\n",
        "    actions = self._action_projection_layer(state)\n",
        "    actions = common_utils.scale_to_spec(actions, self._single_action_spec)\n",
        "    actions = batch_squash.unflatten(actions)\n",
        "    return tf.nest.pack_sequence_as(self._action_spec, [actions]), network_state"
      ]
    },
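    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Two bookkeeping steps in `call` are easy to miss. `BatchSquash` merges the outer batch and time dimensions before the encoder runs and restores them afterwards, and `scale_to_spec` maps the `tanh` output, which lies in `[-1, 1]`, onto the bounds of the action spec. The sketch below mimics both in plain Python under stated assumptions: `squash_shape`, `unsquash_shape` and `scale_tanh_to_bounds` are hypothetical helpers written for this illustration, shapes are tuples rather than tensors, and the scaling is assumed to be the affine map from `[-1, 1]` to `[minimum, maximum]`.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "def squash_shape(shape, outer_rank):\n",
        "    # Merge the first `outer_rank` dimensions (for example batch and time)\n",
        "    # into a single batch dimension. Also return the merged dimensions so\n",
        "    # that they can be restored later.\n",
        "    outer, inner = tuple(shape[:outer_rank]), tuple(shape[outer_rank:])\n",
        "    return (math.prod(outer),) + inner, outer\n",
        "\n",
        "def unsquash_shape(shape, outer):\n",
        "    # Replace the single batch dimension with the original outer dimensions.\n",
        "    return tuple(outer) + tuple(shape[1:])\n",
        "\n",
        "def scale_tanh_to_bounds(value, minimum, maximum):\n",
        "    # Affine map from [-1, 1] to [minimum, maximum].\n",
        "    mean = (maximum + minimum) / 2.0\n",
        "    magnitude = (maximum - minimum) / 2.0\n",
        "    return mean + magnitude * value\n",
        "\n",
        "# Image observations with batch size 4 and 7 time steps.\n",
        "flat_shape, outer = squash_shape((4, 7, 16, 16, 3), outer_rank=2)\n",
        "print(flat_shape)                      # (28, 16, 16, 3)\n",
        "print(unsquash_shape((28, 3), outer))  # (4, 7, 3)\n",
        "\n",
        "# Example action bounds: minimum=0, maximum=10.\n",
        "print(scale_tanh_to_bounds(math.tanh(0.0), 0, 10))  # 5.0\n",
        "print(scale_tanh_to_bounds(-1.0, 0, 10))            # 0.0\n",
        "print(scale_tanh_to_bounds(1.0, 0, 10))             # 10.0\n",
        "```\n",
        "\n",
        "With an outer rank of 1 (batch only) the squash step leaves the shape unchanged, so the same `call` handles both single steps and time sequences."
      ]
    },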
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Fm-MbMMLYiZj"
      },
      "source": [
        "Let's create a `RandomPyEnvironment` to generate structured observations and validate our implementation."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "E2XoNuuD66s5"
      },
      "outputs": [],
      "source": [
        "action_spec = array_spec.BoundedArraySpec((3,), np.float32, minimum=0, maximum=10)\n",
        "observation_spec =  {\n",
        "    'image': array_spec.BoundedArraySpec((16, 16, 3), np.float32, minimum=0,\n",
        "                                        maximum=255),\n",
        "    'vector': array_spec.BoundedArraySpec((5,), np.float32, minimum=-100,\n",
        "                                          maximum=100)}\n",
        "\n",
        "random_env = random_py_environment.RandomPyEnvironment(observation_spec, action_spec=action_spec)\n",
        "\n",
        "# Convert the environment to a TFEnv to generate tensors.\n",
        "tf_env = tf_py_environment.TFPyEnvironment(random_env)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LM3uDTD7TNVx"
      },
      "source": [
        "Since we've defined the observations to be a dict we need to create preprocessing layers to handle these."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "r9U6JVevTAJw"
      },
      "outputs": [],
      "source": [
        "preprocessing_layers = {\n",
        "    'image': tf.keras.models.Sequential([tf.keras.layers.Conv2D(8, 4),\n",
        "                                        tf.keras.layers.Flatten()]),\n",
        "    'vector': tf.keras.layers.Dense(5)\n",
        "    }\n",
        "preprocessing_combiner = tf.keras.layers.Concatenate(axis=-1)\n",
        "actor = ActorNetwork(tf_env.observation_spec(), \n",
        "                     tf_env.action_spec(),\n",
        "                     preprocessing_layers=preprocessing_layers,\n",
        "                     preprocessing_combiner=preprocessing_combiner)"
      ]
    },
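    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see what the encoder receives, it helps to work out the size of the combined feature vector. The sketch below does the arithmetic in plain Python, assuming the Keras defaults for `Conv2D` (stride 1 and `'valid'` padding). `conv_output_size` is a hypothetical helper written for this illustration, not part of TF-Agents.\n",
        "\n",
        "```python\n",
        "def conv_output_size(input_size, kernel_size, stride=1):\n",
        "    # Output length along one spatial axis for a convolution without padding.\n",
        "    return (input_size - kernel_size) // stride + 1\n",
        "\n",
        "# 'image' branch: Conv2D(8, 4) on a 16x16x3 image, followed by Flatten.\n",
        "side = conv_output_size(16, 4)    # 13\n",
        "image_features = side * side * 8  # 13 * 13 * 8 = 1352\n",
        "\n",
        "# 'vector' branch: Dense(5) on a 5-element vector.\n",
        "vector_features = 5\n",
        "\n",
        "# Concatenate(axis=-1) joins both branches along the feature axis.\n",
        "combined_features = image_features + vector_features\n",
        "print(combined_features)  # 1357\n",
        "```\n",
        "\n",
        "The encoder's fully connected layers, `fc_layer_params=(75, 40)` by default in the `ActorNetwork` above, then operate on this 1357-element vector for each batch entry."
      ]
    },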
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mM9qedlwc41U"
      },
      "source": [
        "Now that we have the actor network we can process observations from the environment."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JOkkeu7vXoei"
      },
      "outputs": [],
      "source": [
        "time_step = tf_env.reset()\n",
        "actor(time_step.observation, time_step.step_type)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ALGxaQLWc9GI"
      },
      "source": [
        "This same strategy can be used to customize any of the main networks used by the agents. You can define whatever preprocessing and connect it to the rest of the network. As you define your own custom make sure the output layer definitions of the network match."
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "8_networks_tutorial.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
